# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Loading pretrained models.
"""

import logging
from pathlib import Path
import typing as tp

# from dora.log import fatal

import logging

from diffq import DiffQuantizer
import torch.hub

from .model import Demucs
from .tasnet_v2 import ConvTasNet
from .utils import set_state

from .hdemucs import HDemucs
from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError  # noqa

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Base URL that `_parse_remote_files` prepends to every model entry it parses.
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
# "remote" folder sitting next to this file; `get_model` reads "files.txt"
# from it and hands it to `BagOnlyRepo` when no local repo is given.
REMOTE_ROOT = Path(__file__).parent / "remote"

# Source names passed as `sources=` to the model constructors in this module.
SOURCES = ["drums", "bass", "other", "vocals"]


def demucs_unittest():
    """Build a 4-channel `HDemucs` over `SOURCES`; no weights are loaded.

    NOTE: a second `demucs_unittest` is defined later in this module and
    rebinds the name, so once the module is fully imported this definition is
    no longer the one reached through `demucs_unittest(...)`.
    """
    return HDemucs(channels=4, sources=SOURCES)


def add_model_flags(parser):
    """Register the model-selection options on `parser`.

    `-s/--sig` and `-n/--name` are added to one mutually exclusive (optional)
    group; `--repo` is added to the parser itself and is parsed as a `Path`.
    """
    selector = parser.add_mutually_exclusive_group(required=False)
    selector.add_argument("-s", "--sig", help="Locally trained XP signature.")
    selector.add_argument(
        "-n",
        "--name",
        default="mdx_extra_q",
        help="Pretrained model name or signature. Default is mdx_extra_q.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )


def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
    root: str = ""
    models: tp.Dict[str, str] = {}
    for line in remote_file_list.read_text().split("\n"):
        line = line.strip()
        if line.startswith("#"):
            continue
        elif line.startswith("root:"):
            root = line.split(":", 1)[1].strip()
        else:
            sig = line.split("-", 1)[0]
            assert sig not in models
            models[sig] = ROOT_URL + root + line
    return models


def get_model(name: str, repo: tp.Optional[Path] = None):
    """Load a model (or bag of models) by name.

    `name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is not None.

    Args:
        name: bag-of-models name or pretrained signature. The special value
            ``"demucs_unittest"`` bypasses the repos and returns
            ``demucs_unittest()`` directly.
        repo: folder holding local models. When None, signatures are resolved
            through the index file ``REMOTE_ROOT / "files.txt"``.

    Returns:
        The model returned by ``AnyModelRepo.get_model(name)`` after
        ``.eval()`` has been called on it. The ``"demucs_unittest"`` shortcut
        returns the model as built, without calling ``.eval()``.

    Raises:
        SystemExit: if `repo` is given but is not an existing directory.
    """
    if name == "demucs_unittest":
        return demucs_unittest()
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / "files.txt")
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        if not repo.is_dir():
            # The `fatal` helper this code used to call is not importable in
            # this module (its import is commented out), so calling it raised
            # NameError. Reproduce its contract instead: SystemExit with a
            # string prints the message to stderr and exits with status 1.
            raise SystemExit(f"FATAL: {repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    model = any_repo.get_model(name)
    model.eval()
    return model


def get_model_from_args(args):
    """Load the model selected by parsed command-line arguments.

    Reads `args.name` and `args.repo` and forwards them to `get_model`;
    no other attribute of `args` is consulted here.
    """
    model_name, model_repo = args.name, args.repo
    return get_model(name=model_name, repo=model_repo)


# NOTE: `logger` is already bound to the same value earlier in this module;
# this second assignment is redundant.
logger = logging.getLogger(__name__)
# Base URL that `get_url` prepends to the checkpoint file name.
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"

# Pretrained model name -> signature. `get_url` embeds the first 8 characters
# of the signature in the checkpoint file name ("<name>-<sig>.th").
PRETRAINED_MODELS = {
    "demucs": "e07c671f",
    "demucs48_hq": "28a1282c",
    "demucs_extra": "3646af93",
    "demucs_quantized": "07afea75",
    "tasnet": "beb46fac",
    "tasnet_extra": "df3777b2",
    "demucs_unittest": "09ebc15f",
}

# NOTE: re-assignment of `SOURCES` with the same contents as the list defined
# earlier in this module.
SOURCES = ["drums", "bass", "other", "vocals"]


def get_url(name):
    """Return the download URL of the pretrained checkpoint called `name`.

    The URL is `ROOT`, then ``<name>-<first 8 chars of signature>.th``.
    Raises `KeyError` when `name` is not a key of `PRETRAINED_MODELS`.
    """
    signature = PRETRAINED_MODELS[name]
    return "".join((ROOT, name, "-", signature[:8], ".th"))


def is_pretrained(name):
    """Tell whether `name` is one of the keys of `PRETRAINED_MODELS`."""
    known = name in PRETRAINED_MODELS
    return known


def load_pretrained(name):
    """Build the pretrained model registered under `name`.

    Each supported name maps to one call of `demucs`, `demucs_unittest` or
    `tasnet` with ``pretrained=True`` and the flags matching that variant.

    Raises:
        ValueError: if `name` is not one of the supported names.
    """
    # Demucs variants.
    if name == "demucs":
        return demucs(pretrained=True)
    if name == "demucs48_hq":
        return demucs(pretrained=True, hq=True, channels=48)
    if name == "demucs_extra":
        return demucs(pretrained=True, extra=True)
    if name == "demucs_quantized":
        return demucs(pretrained=True, quantized=True)
    if name == "demucs_unittest":
        return demucs_unittest(pretrained=True)

    # Conv-TasNet variants.
    if name == "tasnet":
        return tasnet(pretrained=True)
    if name == "tasnet_extra":
        return tasnet(pretrained=True, extra=True)

    raise ValueError(f"Invalid pretrained name {name}")


def _load_state(name, model, quantizer=None):
    """Download the checkpoint for `name` and apply it to `model`.

    The state is fetched from `get_url(name)` through `torch.hub`, mapped to
    CPU with hash checking enabled, then handed to `set_state` together with
    `model` and `quantizer`. When a truthy `quantizer` is given, its
    `detach()` method is called afterwards.
    """
    checkpoint = torch.hub.load_state_dict_from_url(
        get_url(name), map_location="cpu", check_hash=True
    )
    set_state(model, quantizer, checkpoint)
    if not quantizer:
        return
    quantizer.detach()


def demucs_unittest(pretrained=True):
    """Build the 4-channel `Demucs` used for unit tests.

    When `pretrained` is truthy, the "demucs_unittest" checkpoint is loaded
    into the model through `_load_state` before it is returned.
    """
    net = Demucs(channels=4, sources=SOURCES)
    if not pretrained:
        return net
    _load_state("demucs_unittest", net)
    return net


def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    """Build a `Demucs` model, optionally loading a pretrained checkpoint.

    Args:
        pretrained: load pretrained weights into the model.
        extra: select the "_extra" checkpoint variant.
        quantized: select the "_quantized" variant; the state is loaded
            through a `DiffQuantizer` wrapped around the model.
        hq: select the "_hq" checkpoint variant.
        channels: forwarded to `Demucs`; when different from 64 it is also
            appended to the checkpoint name (e.g. "demucs48_hq").

    Returns:
        The constructed model, with weights loaded when `pretrained` is truthy.

    Raises:
        ValueError: if a variant flag is set while `pretrained` is false, or
            if more than one of `extra`, `quantized`, `hq` is set.

    Note:
        The checkpoint name assembled here is looked up by `_load_state`;
        combinations that do not correspond to a published checkpoint fail
        during that lookup.
    """
    variant_flags = [extra, quantized, hq]
    # Validate every flag combination up front so an invalid call fails
    # before a model is allocated.
    if not pretrained and any(variant_flags):
        raise ValueError("if extra, quantized or hq is True, pretrained must be True.")
    if pretrained and sum(variant_flags) > 1:
        raise ValueError("Only one of extra, quantized, hq, can be True.")

    model = Demucs(sources=SOURCES, channels=channels)
    if not pretrained:
        return model

    # Checkpoint name: "demucs" + optional channel count + optional variant.
    name = "demucs"
    if channels != 64:
        name += str(channels)
    quantizer = None
    if quantized:
        quantizer = DiffQuantizer(model, group_size=8, min_size=1)
        name += "_quantized"
    if extra:
        name += "_extra"
    if hq:
        name += "_hq"
    _load_state(name, model, quantizer)
    return model


def tasnet(pretrained=True, extra=False):
    """Build a `ConvTasNet` (``X=10``) over `SOURCES`.

    When `pretrained` is truthy, the "tasnet" checkpoint is loaded, or
    "tasnet_extra" when `extra` is truthy.

    Raises:
        ValueError: if `extra` is set while `pretrained` is false.
    """
    if not pretrained and extra:
        raise ValueError("if extra is True, pretrained must be True.")
    net = ConvTasNet(X=10, sources=SOURCES)
    if not pretrained:
        return net
    checkpoint_name = "tasnet_extra" if extra else "tasnet"
    _load_state(checkpoint_name, net)
    return net
